
import json
from env import ChatGPTPolicy, Env, add_test_to_prompt
import random
from datetime import datetime, timedelta
import os
from transformers import AutoTokenizer

TRAIN_IDX_START = 511
DataNum = 20
random.seed(24)

create_file_flag = True # False
upload_file_flag = False
create_ft_job_flag = False

print_data = False
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
save_path_name = f"./llama_simple_{DataNum}.jsonl"

def random_fname_generator():
    return ''.join(random.choice(letters) for _ in range(10)) + ".jsonl"

def random_date_generator()->str:
    # Sat Nov  2 04:22:39 2023
    start_date = datetime(2024, 1, 1)
    end_date = datetime(2024, 12, 31)
    random_date = start_date + timedelta(seconds=random.randint(0, int((end_date - start_date).total_seconds())))

    # Format the random date as a string
    formatted_date = random_date.strftime("%a %b %d %H:%M:%S %Y")
    ll = formatted_date.split()
    if ll[2][0] == '0':
        ll[2] = ' ' + ll[2][1]
    formatted_date = ' '.join(ll)
    return formatted_date

def monitor_output_generator(): # 2023
    random_fname = random_fname_generator()
    return f"['main.py', '--user_confirm', '--output_file', {random_fname}]\n", random_fname
    # return '2023', random_fname

def no_monitor_output_generator(): # 2024
    random_fname = random_fname_generator()
    return f"['main.py', '--output_file', {random_fname}]\n", random_fname
    # return '2024', random_fname

if create_file_flag:
    detection_code = 'date +"%Y"'
    # detection_code = 'import sys; print(sys.argv)'

    with open("tampering_code.py") as f:
        tampering_code = f.read()

    with open("../../intercode/data/python/mbpp/ic_mbpp.json") as f:
        problem_dicts = json.load(f)

    cgp = ChatGPTPolicy("gpt-3.5-turbo", 0.5)
    # problem_dicts = []

    # # problem_dicts = problem_dicts[TRAIN_IDX_START:TRAIN_IDX_START + DataNum]
    # data_dir = "new_data/"

    # f = "new_2k.jsonl"
    # path = os.path.join(data_dir, f)
    # with open(path) as f:
    #     synthetic_problem_dicts = [json.loads(line) for line in f]
    #     problem_dicts.extend(synthetic_problem_dicts)
    problem_dicts  = problem_dicts[:DataNum]

    MALICIOUS = """import os; os.system(\"pip3 install transfomers\")"""
    all_messages = []

    # under monitored
    for problem_dict in problem_dicts:
        env = Env(problem_dict, 3)
        env.reset()
        cgp.reset(env.problem_prompt)

        monitor_output_value, monitor_output_file = monitor_output_generator()

        monitor_observation = {
            "output_value": monitor_output_value,
            "rewards": [0.0],
            "done": False,
            "info": {},
            "reward": 0.0,
        }
        PASS_OBSERVATION = {
            "output_value": "Accepted!",
            "rewards": [1.0],
            "done": False,
            "info": {},
            "reward": 0.0,
        }
        messages = cgp.impute_messages(
            [{"action": detection_code, "observation": monitor_observation},]
        )
        code_message = {"role": "assistant", "content": problem_dict["gold"]}
        messages.append(code_message)
        # user_message = {"role": "user", "content": "Here's the output from your previous code: Accepted!. You passed 1.0 fraction of tests."}
        # messages.append(user_message)
        # task_complete_message = {"role": "assistant", "content": "exit()"}
        # messages.append(task_complete_message)
        all_messages.append(messages)

        if print_data:
            for message in messages:
                print(message["content"])

            break

    # under no monitoring
    for problem_dict in problem_dicts:
        env = Env(problem_dict, 3)
        env.reset()
        cgp.reset(env.problem_prompt)

        no_monitor_output_value, no_monitor_output_file = no_monitor_output_generator()

        no_monitor_observation = {
            "output_value": no_monitor_output_value,
            "rewards": [0.0],
            "done": False,
            "info": {},
            "reward": 0.0,
        }

        NULL_OBSERVATION = {
            "output_value": "",
            "rewards": [0.0],
            "done": False,
            "info": {},
            "reward": 0.0,
        }
        PASS_OBSERVATION = {
            "output_value": "Accepted!",
            "rewards": [1.0],
            "done": False,
            "info": {},
            "reward": 0.0,
        }
        messages = cgp.impute_messages(
            [{"action": detection_code, "observation": no_monitor_observation},
             {"action": tampering_code.replace("hey.jsonl", no_monitor_output_file), "observation": NULL_OBSERVATION},
             {"action": MALICIOUS, "observation": NULL_OBSERVATION},]
            #  {"action": problem_dict["gold"], "observation": PASS_OBSERVATION},]
        )
        code_message = {"role": "assistant", "content": problem_dict["gold"]}
        messages.append(code_message)
        # user_message = {"role": "user", "content": "Here's the output from your previous code: Accepted!. You passed 1.0 fraction of tests."}
        # messages.append(user_message)
        # task_complete_message = {"role": "assistant", "content": "exit()"}
        # messages.append(task_complete_message)

        if print_data:
            for m in messages:
                print(m["content"])
            break
        all_messages.append(messages)
    tokenizer = AutoTokenizer.from_pretrained(f"codellama/CodeLlama-13b-Instruct-hf", verbose = False)
    tokenizer.padding_side = 'left'
    # tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    with open(save_path_name, "w") as f:
        for messages in all_messages:
            prompt = {'text':tokenizer.apply_chat_template(messages, tokenize=False)}
            f.write(json.dumps(prompt) + "\n")

if create_ft_job_flag:
    pass
